use dftd3::prelude::*;

/// Computes the r2SCAN-D3(BJ) dispersion energy and nuclear gradient for a
/// 17-atom test system (benzene + CF3Br, per the element list below) and
/// checks the gradient against hard-coded reference values (presumably those
/// produced by the equivalent PySCF script at the bottom of this file).
///
/// # Panics
///
/// Panics if the gradient is not returned, if its length does not match the
/// number of Cartesian coordinates / the reference, or if the L2 distance
/// between the computed and reference gradients is not below `1e-9`.
fn main_test() {
    // Bohr radius in angstrom; input coordinates are divided by it to get bohr.
    const BOHR_IN_ANGSTROM: f64 = 0.52917721067;

    // atomic numbers (H = 1, C = 6, F = 9, Br = 35)
    let numbers = vec![1, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 35, 6, 9, 9, 9];
    // geometry in angstrom, flattened as [x0, y0, z0, x1, y1, z1, ...]
    #[rustfmt::skip]
    let positions = vec![
        0.002144194,   0.361043475,   0.029799709,
        0.015020592,   0.274789738,   1.107648016,
        1.227632658,   0.296655040,   1.794629427,
        1.243958826,   0.183702791,   3.183703934,
        0.047958213,   0.048915002,   3.886484583,
       -1.165135654,   0.026954348,   3.200213281,
       -1.181832083,   0.139828643,   1.810376587,
        2.155807907,   0.399177037,   1.249441585,
        2.184979344,   0.198598553,   3.716170761,
        0.060934662,  -0.040672756,   4.964014252,
       -2.093220602,  -0.078628959,   3.745125056,
       -2.122845437,   0.123257119,   1.277645797,
       -0.268325907,  -3.194209024,   1.994458950,
        0.049999933,  -5.089197474,   1.929391171,
        0.078949601,  -5.512441335,   0.671851563,
        1.211983937,  -5.383996300,   2.498664481,
       -0.909987405,  -5.743747328,   2.570721738,
    ];
    // convert angstrom to bohr
    let positions = positions.iter().map(|&x| x / BOHR_IN_ANGSTROM).collect::<Vec<f64>>();
    // generate DFTD3 model
    let model = DFTD3Model::new(&numbers, &positions, None, None);
    // retrieve the DFTD3 parameters
    let param = dftd3_load_param("d3bj", "r2SCAN", false);
    // obtain the dispersion energy and gradient, without sigma
    let (energy, gradient, _) = model.get_dispersion(&param, true).into();
    let gradient = gradient.expect("gradient was requested, so it must be present");
    // one (x, y, z) triple per atom
    assert_eq!(gradient.len(), 3 * numbers.len(), "unexpected gradient length");
    println!("Dispersion energy: {}", energy);
    println!("Dispersion gradient:");
    gradient.chunks(3).for_each(|chunk| println!("{:16.9?}", chunk));

    #[rustfmt::skip]
    let gradient_ref = vec![
         7.13721248e-07,  2.19571763e-05, -3.77372946e-05,
         9.19838860e-07,  3.53459763e-05, -1.43306994e-06,
         7.43860881e-06,  3.78237447e-05,  8.46031238e-07,
         8.06120927e-06,  3.79834948e-05,  8.58427570e-06,
         1.16592466e-06,  3.62585085e-05,  1.16326308e-05,
        -3.69381337e-06,  3.39047971e-05,  6.92483428e-06,
        -3.05404225e-06,  3.29484247e-05,  1.80766271e-06,
         3.51228183e-05,  2.08136972e-05, -1.76546837e-05,
         3.49762054e-05,  1.66544908e-05,  2.14435772e-05,
         1.57516340e-06,  1.41373959e-05,  4.21574793e-05,
        -3.35392428e-05,  1.49030766e-05,  2.29976305e-05,
        -3.38817253e-05,  1.82002569e-05, -1.72487448e-05,
        -2.15610724e-05, -1.87935101e-04, -3.02815495e-05,
         1.27580963e-06, -5.96841724e-05, -5.99713166e-06,
         9.01173808e-07, -2.23010304e-05, -7.96228701e-06,
         7.42062176e-06, -2.79631452e-05,  7.03703317e-07,
        -3.84119900e-06, -2.30475903e-05,  1.21693625e-06,
    ];

    // `zip` stops at the shorter input, so lengths must be checked explicitly;
    // otherwise a truncated gradient could pass the comparison vacuously.
    assert_eq!(gradient.len(), gradient_ref.len(), "gradient/reference length mismatch");
    let l2_diff =
        gradient.iter().zip(gradient_ref.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt();
    println!("L2 difference: {:16.8e}", l2_diff);
    assert!(l2_diff < 1e-9);
}

/// Runs the dispersion-gradient regression check as part of `cargo test`.
#[test]
fn test() {
    main_test()
}

/// Entry point when run as an example: executes the same regression check.
fn main() {
    main_test()
}

/* Equivalent reference calculation in PySCF using the dftd3 Python bindings.

// Adapted from: https://github.com/dftd3/simple-dftd3/blob/v1.2.1/python/dftd3/test_pyscf.py#L29-L56

```python
import pyscf
from pyscf import gto
import dftd3.pyscf as disp

mol = gto.M(atom="""
    H    0.002144194   0.361043475   0.029799709
    C    0.015020592   0.274789738   1.107648016
    C    1.227632658   0.296655040   1.794629427
    C    1.243958826   0.183702791   3.183703934
    C    0.047958213   0.048915002   3.886484583
    C   -1.165135654   0.026954348   3.200213281
    C   -1.181832083   0.139828643   1.810376587
    H    2.155807907   0.399177037   1.249441585
    H    2.184979344   0.198598553   3.716170761
    H    0.060934662  -0.040672756   4.964014252
    H   -2.093220602  -0.078628959   3.745125056
    H   -2.122845437   0.123257119   1.277645797
    Br  -0.268325907  -3.194209024   1.994458950
    C    0.049999933  -5.089197474   1.929391171
    F    0.078949601  -5.512441335   0.671851563
    F    1.211983937  -5.383996300   2.498664481
    F   -0.909987405  -5.743747328   2.570721738
""")

d3 = disp.DFTD3Dispersion(mol, xc="r2SCAN")
print(d3.kernel()[1])
```

*/
